# Load the helper modules only on the first run of this script in a session.
# Later re-runs skip this block; files loaded with Revise's `includet` are kept
# in sync with their sources, so they need not be included again.
if !((@isdefined run_once) && run_once)

  using Revise

  for utility ∈ ("AlgorithmTools", "GradientAscent", "PlottingTools")
    includet("../Utilities/$(utility).jl")
  end

  run_once = true

end



using LinearAlgebra
using StaticArrays
using Random
using DynamicPolynomials
using Flux
using ProgressLogging

import StaticPolynomials as SP


using .GradientAscent
using .PlottingTools



# Two players, each choosing among three actions (rock, paper, scissors).
const PLAYERS, ACTIONS = 2, 3

"""
    pseudoGradientOfExpectedPayoffsFunction()

Return a closure `state -> SVector(g₁, g₂)` giving, for each player, the gradient
of its expected payoff with respect to its own mixed strategy.

With the payoff matrix `A` below, player 1's gradient is `A * state[2]` and
player 2's gradient is `A * state[1]` (the same matrix is used for both players).
Each gradient is returned as an `MVector{ACTIONS}` so callers can adjust it in place.
"""
function pseudoGradientOfExpectedPayoffsFunction()

  # Payoff matrix written row by row: entry [a, b] is the payoff of action `a`
  # against action `b`. Every row sums to zero, so `A` maps the uniform strategy to 0.
  # NOTE(review): if actions are ordered R, P, S (as the plot labels suggest),
  # entry [1, 2] = +1 means "rock beats paper", the reverse of the usual
  # orientation. Confirm this is intended; the uniform equilibrium is unaffected.
  A = [
     0   1  -1
    -1   0   1
     1  -1   0
  ]

  @polyvar x[1:ACTIONS]

  # Compile the linear map x ↦ A * x into a StaticPolynomials system for fast evaluation.
  gradient_function = SP.PolynomialSystem(A * x)

  return @inline @inbounds function(state)
    # A player's gradient depends only on the opponent's strategy.
    SVector{PLAYERS, MVector{ACTIONS}}(
      gradient_function(state[2]),
      gradient_function(state[1])
    )
  end

end

"""
    expectedPayoffsFunction()

Return a closure `state -> SVector(u₁, u₂)` of the players' expected payoffs.

Payoffs are bilinear in the strategies, so player `i`'s payoff equals the inner
product of its own strategy `state[i]` with its pseudo-gradient.
"""
function expectedPayoffsFunction()

  pseudo_gradient_function = pseudoGradientOfExpectedPayoffsFunction()

  return @inline @inbounds function(state)
    gradients = pseudo_gradient_function(state)
    SVector{PLAYERS}(ntuple(player -> dot(state[player], gradients[player]), PLAYERS))
  end

end


"""
    gradientFunction(models, inputs; with_preconditioning=false, convexity_factor=0)

Build the vector field for gradient ascent on the players' network inputs.

The returned closure takes `state`, where `state[i]` is the input of `models[i]`.
For each player it:
1. evaluates the model and its Jacobian with respect to `state[i]`;
2. evaluates the game's pseudo-gradient in action (strategy) space;
3. if `convexity_factor > 0`, subtracts `convexity_factor * (xᵢ - 1/ACTIONS)`,
   which pulls strategies toward uniform;
4. maps the result back to input space. Plain ascent uses the Jacobian's
   transpose; with `with_preconditioning=true` it uses the Moore–Penrose
   pseudo-inverse.

The closure returns an `SVector{PLAYERS}` of `SVector{inputs}` gradients.
"""
function gradientFunction(models, inputs; with_preconditioning=false, convexity_factor=0)

  pseudo_gradient_function = pseudoGradientOfExpectedPayoffsFunction()

  # Operator that turns an (ACTIONS × inputs) Jacobian into the map pulling an
  # action-space gradient back to input space.
  pull_back = with_preconditioning ? pinv : adjoint

  return @inline @inbounds function(state)

    strategies = SVector{ACTIONS}[]
    pull_backs = MMatrix{inputs, ACTIONS}[]

    for player ∈ 1:PLAYERS
      strategy, jacobian = Flux.withjacobian(models[player], state[player])
      push!(strategies, strategy)
      # `only` unwraps the single Jacobian (one argument was differentiated).
      push!(pull_backs, pull_back(only(jacobian)))
    end

    # Entries are mutable vectors, so the regularization below updates them in place.
    pseudo_gradient = pseudo_gradient_function(strategies)

    if convexity_factor > 0
      for player ∈ 1:PLAYERS
        pseudo_gradient[player] .-= convexity_factor .* (strategies[player] .- 1 / ACTIONS)
      end
    end

    SVector{PLAYERS}(
      SVector{inputs}(pull_backs[player] * pseudo_gradient[player])
      for player ∈ 1:PLAYERS
    )

  end

end



"""
    realizeTrajectories(models, inputs; initial_states, iterations=Int(1e3),
                        convexity_factor=0.0, step=1e-2, with_preconditioning=false,
                        with_progress=true, kwargs...)

Run `gradientAscent` from every entry of `initial_states`, using the field built
by `gradientFunction(models, inputs; with_preconditioning, convexity_factor)`.
Each player's rounded mixed strategy is printed before and after every run.

Extra keywords read from `kwargs`:
- `process_name`: label forwarded to `gradientAscent` (default `"Gradient Ascent"`);
- `prompt`: prefix of the printed final-strategy line (default `"GD"`).

Returns a `Vector{Any}` with one trajectory per initial state, in input order.
"""
function realizeTrajectories(
    models, inputs;
    initial_states,
    iterations=Int(1e3),
    convexity_factor=0.0,
    step=1e-2,
    with_preconditioning=false,
    with_progress=true,
    kwargs...
  )

  total_trajectories = length(initial_states)
  process_name = get(kwargs, :process_name, "Gradient Ascent")
  prompt = get(kwargs, :prompt, "GD")

  # The gradient closure allocates fresh buffers on every call and keeps no state
  # between calls, so a single instance can drive every run.
  gradient = gradientFunction(
    models, inputs;
    with_preconditioning=with_preconditioning,
    convexity_factor=convexity_factor
  )

  # Each player's mixed strategy at `state`, rounded for display.
  rounded_strategies(state) = [round.(models[p](state[p]); digits=3) for p ∈ 1:PLAYERS]

  trajectories = Any[]

  for k ∈ axes(initial_states, 1)

    initial_state = initial_states[k]
    println("[$k / $total_trajectories] x₀ ≈ $(repr(rounded_strategies(initial_state)))")

    trajectory = gradientAscent(
      initial_state, gradient;
      iterations=iterations,
      step=step,
      with_progress=with_progress,
      # NOTE(review): ProgressLogging.jl documents `@progressid` as valid only
      # inside `@withprogress`. Confirm this resolves with the installed version.
      progress_id=@progressid,
      # This run's slice of the overall progress range [0, 1].
      progress_partition=[k - 1, k] ./ total_trajectories,
      process_name=process_name
    )

    println("$prompt: xₑ ≈ $(repr(rounded_strategies(trajectory[end].state)))")

    push!(trajectories, trajectory)

  end

  # Mark the shared progress record as finished.
  @info ProgressLogging.Progress(@progressid; done=true)

  return trajectories

end



rng = MersenneTwister(4)

"""
    randomInitialization(sz...; gain=1.0)

Draw an array of size `sz` with entries uniform in `[-gain, gain)` from the
global, fixed-seed `rng`.

Called with no sizes, return a curried initializer `(sz...) -> …` with `gain`
fixed. This form is passed as the `init` keyword of the `Dense` layers.
"""
function randomInitialization(sz...; gain=1.0)

  isempty(sz) && return (dims...) -> randomInitialization(dims...; gain=gain)

  # rand ∈ [0, 1)  ⇒  rand - 0.5 ∈ [-0.5, 0.5)  ⇒  scaled by 2 * gain.
  return 2 * gain * (rand(rng, sz...) .- 0.5)

end


# Input dimension of each player's network. The gradient dynamics act on these inputs.
inputs = 5

# One network per player: inputs -> 4 hidden units (celu) -> ACTIONS -> softmax,
# i.e. a map from an input vector to a mixed strategy. Weights are drawn from the
# seeded `rng` (player 1's network first), and `f64` converts the parameters to Float64.
models = SVector{PLAYERS}(ntuple(PLAYERS) do _
  f64(Chain(
    Dense(inputs => 4, celu; init=randomInitialization(; gain=1.0)),
    Dense(4 => ACTIONS; init=randomInitialization(; gain=1.0)),
    softmax
  ))
end)

# Euclidean distance between the players' current mixed strategies (stacked into
# one vector) and the uniform strategy 1/ACTIONS.
metric = @inline @inbounds function(state)

  deviations = [models[player](state[player]) .- 1 / ACTIONS for player ∈ 1:PLAYERS]

  return norm(reduce(vcat, deviations))

end


# Number of runs per method, and the fraction of extra samples drawn at each
# extreme so that outliers can be discarded.
total_trajectories = 100
drop_out_rate = 0.15
# (sic) misspelled name kept because it is a script-level global.
additional_trajectotires = round(Int, total_trajectories * drop_out_rate)

# Sample total + 2 * additional random input states, sort them by `metric` from
# farthest to nearest to uniform, then drop `additional_trajectotires` states at
# each end. The remaining list keeps that descending order.
initial_states = let
  candidate_count = total_trajectories + 2 * additional_trajectotires

  candidates = [
    SVector{PLAYERS}(
      MVector{inputs}(randomInitialization(inputs; gain=0.75))
      for _ ∈ 1:PLAYERS
    )
    for _ ∈ 1:candidate_count
  ]

  sort!(candidates; by=metric, rev=true)

  candidates[(begin + additional_trajectotires):(end - additional_trajectotires)]
end



# Settings shared by the GD and PGD runs.
iterations = 10^4
convexity_factor = 0.2

# Plot settings:
# - `selected_index`: which trajectory to draw in the simplex plots, 10% of the
#   way into the metric-sorted list;
# - `y_limits`: vertical range for the mean-metric plots;
# - `plot_type`: marker style for the simplex plots.
selected_index = ceil(Int, total_trajectories * 0.1)
y_limits = (0.0, (ACTIONS - 1) / ACTIONS)
plot_type = PlottingTools.PlotTypes.ScatterPlot


# Plain gradient ascent (Jacobian transpose) from every initial state.
trajectories_GD = realizeTrajectories(
  models, inputs;
  initial_states=initial_states,
  iterations=iterations,
  convexity_factor=convexity_factor,
  step=t -> 1e-2
)


filename = "RockPaperScissorsGame_GD"
newFigure(filename) do matlab_session
  plotTrajectoryOfMean(
    matlab_session, trajectories_GD;
    transform=metric,
    y_limits=y_limits
  )
end

# Draw the selected trajectory in each player's strategy simplex. A mixed strategy
# over ACTIONS sums to one, so its first ACTIONS - 1 probabilities determine it.
# The slice is indexed by ACTIONS, not PLAYERS; the two only coincide when
# PLAYERS == ACTIONS - 1.
trajectory = trajectories_GD[selected_index]
for (player, color) ∈ zip(1:PLAYERS, ("'Blue'", "'Red'"))
  newFigure(filename; append=true) do matlab_session
    plotTrajectoryInSimplex(
      matlab_session, trajectory;
      transform=state -> (models[player](state[player]))[1:(ACTIONS - 1)],
      color=color,
      plot_type=plot_type,
      point_text=["'R'", "'P'", "'S'"]
    )
  end
end


# Preconditioned gradient ascent (Jacobian pseudo-inverse) from the same initial states.
trajectories_PGD = realizeTrajectories(
  models, inputs;
  initial_states=initial_states,
  iterations=iterations,
  convexity_factor=convexity_factor,
  step=t -> 1e-2,
  with_preconditioning=true,
  process_name="Preconditioning Gradient Ascent",
  prompt="PGD"
)

filename = "RockPaperScissorsGame_PGD"
newFigure(filename) do matlab_session
  plotTrajectoryOfMean(
    matlab_session, trajectories_PGD;
    transform=metric,
    y_limits=y_limits
  )
end

# Draw the selected trajectory in each player's strategy simplex. A mixed strategy
# over ACTIONS sums to one, so its first ACTIONS - 1 probabilities determine it.
# The slice is indexed by ACTIONS, not PLAYERS; the two only coincide when
# PLAYERS == ACTIONS - 1.
trajectory = trajectories_PGD[selected_index]
for (player, color) ∈ zip(1:PLAYERS, ("'Blue'", "'Red'"))
  newFigure(filename; append=true) do matlab_session
    plotTrajectoryInSimplex(
      matlab_session, trajectory;
      transform=state -> (models[player](state[player]))[1:(ACTIONS - 1)],
      color=color,
      plot_type=plot_type,
      point_text=["'R'", "'P'", "'S'"]
    )
  end
end


# Compare the mean distance to the uniform strategy for PGD vs GD on one log-scale plot.
filename = "RockPaperScissorsGame_GDVsPGD"
newFigure(filename) do matlab_session

  # MATLAB RGB triplet string (components in [0, 1]) from 8-bit channel values.
  matlab_rgb(r, g, b) = "[$(r / 255), $(g / 255), $(b / 255)]"

  # PGD is drawn first, carrying the axis options (log scale, limits, grid).
  plotTrajectoryOfMean(
    matlab_session, trajectories_PGD;
    transform=metric,
    y_limits=(1e-10, 1e3),
    log_scale=true,
    grid=true,
    color=matlab_rgb(97, 142, 46),
    name="'PGD'",
    legends=true
  )

  # GD is overlaid on the same figure as a dashed line.
  plotTrajectoryOfMean(
    matlab_session, trajectories_GD;
    transform=metric,
    new_figure=false,
    line_style="'--'",
    color=matlab_rgb(206, 44, 64),
    name="'GD'",
    legends=true
  )

end



# Make `nothing` the script's final value.
nothing;